# -*- coding: utf-8 -*-
import torch
from math import cos, sin, pi
import torch
# Target device for the tensors created in this module: the GPU when torch can
# see one, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Autograd anomaly detection is switched on for the whole process. It is a
# debugging aid (per the torch docs it slows down execution), so consider
# disabling it outside of debugging runs.
torch.autograd.set_detect_anomaly(True)
import os
import torch


class Env:
    """Planar environment in which "hills" damp the agent's movement.

    The hill table is read from a saved object indexed as
    ``loaded[complexity][stype]``. In each row, columns 0-1 are used as the
    hill position and the last column as its spread ("var"). ``gen()`` must be
    called before ``c_d_r()`` / ``step()`` because it fills ``hill_list`` and
    ``var_list``. ``stype == -1`` turns damping off in ``step()``.
    """

    def __init__(self, stype, complexity, path='main experiment/Parameter/circle.pt'):
        """Load the hill table selected by ``complexity`` and ``stype``.

        Args:
            stype: environment type; used as the inner key into the loaded
                table, and ``-1`` disables damping in ``step()``.
            complexity: outer key into the loaded table.
            path: file passed to ``torch.load``. Defaults to the original
                location, which is relative to the current working directory.
        """
        self.type = stype

        # NOTE(review): torch.load unpickles arbitrary objects, so only point
        # this at trusted files. The explicit pickle_module is kept unchanged:
        # on newer torch releases passing it appears to keep weights_only
        # defaulting to False, which the original loading relied on.
        loaded_tensor = torch.load(path, map_location='cpu', pickle_module=torch.serialization.pickle)
        self.my_circle = loaded_tensor[complexity][self.type]
        # Populated by gen(): scaled hill positions (M, 2) and spreads (M, 1).
        self.hill_list = None
        self.var_list = None

    def gen(self):
        """Cache scaled hill parameters and build the damping map.

        Side effects: sets ``self.hill_list`` to the hill positions / 5 and
        ``self.var_list`` to the spreads / 20 (shape (M, 1)), both moved to the
        module-level ``device``.

        Returns:
            ``(sumweight, hill_list, var_list)``. ``sumweight`` is a (101, 101)
            map over the integer grid [-50, 50] x [-50, 50]: the sum over hills
            of ``exp(-sum((p - pos)**2 / spread))`` using the UNSCALED hill
            values, clamped to at most 0.99. After the transpose, axis 0 walks
            the second coordinate and axis 1 the first.
        """
        hill_pos = self.my_circle[:, :2]
        hill_spread = self.my_circle[:, -1].unsqueeze(-1)
        # NOTE(review): the map below uses unscaled hills, while c_d_r() uses
        # positions / 5 and spreads / 20. For the two to describe the same
        # Gaussian in map units the spread would need / 25; confirm intended.
        self.hill_list = hill_pos.to(device) / 5
        self.var_list = hill_spread.to(device) / 20

        axis = torch.arange(-50, 51)
        row_idx, col_idx = torch.meshgrid(axis, axis, indexing="ij")
        basemap = torch.stack((row_idx, col_idx), dim=2)  # (101, 101, 2)

        # Floating-point accumulator: keeps the clamp below well-defined even
        # when there are no hills (the original integer zeros tensor was not).
        sumweight = torch.zeros(basemap.shape[:2])
        for pos, spread in zip(hill_pos, hill_spread):
            squaresum = torch.sum(torch.square(basemap - pos) / spread, -1)
            # Each term lies in (0, 1]; the sum can exceed 1, hence the clamp.
            sumweight = sumweight + torch.exp(-squaresum)

        sumweight = torch.clamp(sumweight, max=0.99).transpose(0, 1)

        return (sumweight, self.hill_list, self.var_list)

    def c_d_r(self, target):
        """Return the damping rate at each position in ``target``.

        Args:
            target: positions of shape (N, 2), in the same scaled units as
                ``self.hill_list``.

        Returns:
            Tensor of shape (N,): ``0.99 * sum_m exp(-sum((target - pos_m)**2
            / spread_m))``. With no hills this is a zero tensor with target's
            dtype and device, so callers can treat it like any other rate.

        Requires ``gen()`` to have been called (``hill_list`` is None before).
        """
        if len(self.hill_list) == 0:
            # Previously returned the int 0, which made step() fail on
            # .unsqueeze(); a tensor keeps the return type uniform.
            return target.new_zeros(target.shape[0])

        # target (N, 2), hill_list (M, 2), var_list (M, 1); the single spread
        # per hill broadcasts over both coordinates.
        target_expanded = target.unsqueeze(1)  # (N, 1, 2)
        meanlist_expanded = self.hill_list.unsqueeze(0)  # (1, M, 2)
        varlist_expanded = self.var_list.unsqueeze(0)  # (1, M, 1)

        # Squared offsets normalised by each hill's spread.
        diff = target_expanded - meanlist_expanded
        weighted_diff = torch.square(diff) / varlist_expanded

        # Per-hill Gaussian-style weight, shape (N, M).
        sum_weighted_diff = torch.sum(weighted_diff, dim=-1)
        exp_weights = torch.exp(-sum_weighted_diff)

        # Sum over hills, shape (N,).
        sumweight = torch.sum(exp_weights, dim=1)

        # NOTE(review): unlike gen(), which clamps at 0.99, this only scales by
        # 0.99, so overlapping hills can give a rate above 1 and make
        # (1 - d_rate) negative in step(); confirm that is intended.
        return sumweight * 0.99

    def step(self, state, action):
        """Advance positions by ``action``, damped by the hills.

        Args:
            state: tensor whose first two columns are the current position.
            action: displacement broadcastable against ``state[:, :2]``.

        Returns:
            The new (N, 2) positions. For ``type == -1`` this is
            ``state[:, :2] + action``. Otherwise the action is scaled by
            ``1 - d_rate``, where ``d_rate`` averages the damping at the start
            and at the undamped end point.
        """
        start = state[:, :2]
        undamped = start + action

        if self.type == -1:
            # Damping disabled: skip the (unused) damping computation.
            return undamped

        first_d_rate = self.c_d_r(start)
        sec_d_rate = self.c_d_r(undamped)
        d_rate = (first_d_rate.unsqueeze(-1) + sec_d_rate.unsqueeze(-1)) / 2

        return start + action * (1 - d_rate)